remove DetachedEagleGPT model and handle all offline mode in the _DynamicEagleGPTModel #321
Walkthrough
Removed the OfflineEagleDMRegistry and detached wrappers. Unified registration and conversion through EagleDMRegistry. Updated conversion to dynamically map subclasses and pass the expanded Eagle config. In the Megatron plugin, added an explicit offline (eagle_offline) forward path, shape/indexing adjustments, and disabled sequence_parallel in offline mode. Removed the detached HF transformers wrapper.
Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant App
participant Conversion as conversion.py
participant Registry as EagleDMRegistry
participant EagleModel
User->>App: convert_to_eagle_model(model, config)
App->>Conversion: convert_to_eagle_model(model, config)
alt Type registered
Conversion->>Registry: lookup type(model)
else Type not registered
Conversion->>Registry: iterate _registry for subclass match
Note over Conversion,Registry: On match, map original class to base via register()
end
Conversion->>Registry: convert(model)
Registry-->>Conversion: eagle_model
Conversion->>EagleModel: modify(config: offline, hidden_state_distillation, self_logit_distillation, freeze_base_model, report_acc, reuse_base_decoder, loss_decay_factor, architecture_config)
EagleModel-->>App: modified eagle_model
sequenceDiagram
autonumber
actor Trainer
participant Model as MegatronEagleModel
participant Base as BaseModel
participant Eagle as EagleModule
Trainer->>Model: forward(input_ids, labels, kwargs)
alt Online
Model->>Base: _base_model_forward(...)
Base-->>Model: hidden_states
opt return_eagle_inputs
Model->>Model: _get_eagle_input_hidden_states(hidden_states, apply_fc=false)
Model-->>Trainer: {input_ids, aux_hidden_states, hidden_states}
end
Model->>Model: _get_eagle_input_hidden_states(hidden_states, apply_fc=true)
else Offline
Note over Model: sequence_parallel disabled
Model->>Model: use aux_hidden_states from kwargs
Model->>Model: _get_eagle_input_hidden_states(aux_hidden_states, apply_fc=true)
Note over Model: If labels len = input_ids-1, pad labels
end
Model->>Eagle: compute logits (multi-step)
Note over Model,Eagle: Shape-based slicing for eagle_logits_N in offline
Eagle-->>Trainer: outputs (logits, losses, etc.)
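A rough sketch of the gating the second diagram describes (names such as eagle_offline, _base_model_forward, and _get_eagle_input_hidden_states are taken from the review below; the body is an illustrative assumption, not the actual implementation):

def forward_gating_sketch(model, input_ids, labels=None, **kwargs):
    """Hypothetical outline of the unified online/offline EAGLE forward path."""
    if model.eagle_offline:
        # Offline: precomputed base-model states arrive via kwargs; the base forward is
        # skipped and sequence_parallel has been disabled at setup time.
        hidden_states = kwargs["hidden_states"]              # [s, b, h]
        aux_hidden_states = kwargs.get("aux_hidden_states")  # [s, b, k*h] for EAGLE-3
        eagle_in = model._get_eagle_input_hidden_states(aux_hidden_states, apply_fc=True)
    else:
        hidden_states = model._base_model_forward(input_ids, **kwargs)
        eagle_in = model._get_eagle_input_hidden_states(hidden_states, apply_fc=True)
    # Multi-step draft logits are then computed from eagle_in; offline runs trim one extra
    # position ([:-2] vs. [:-1]) when slicing them against labels.
    return eagle_in, hidden_states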
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
1-1
: Replace or remove OfflineEagleDMRegistry usages in transformers.py
Found offline-registry references: import at modelopt/torch/speculative/plugins/transformers.py:49 and decorator usage @OfflineEagleDMRegistry.register at line 1144. No matches for DetachedEagle|DetachedEagleGPT or for return_eagle_inputs=True were found. Replace these uses with the supported registry API (e.g., EagleDMRegistry) or remove/update them to match the intended removal of the offline registry.
🧹 Nitpick comments (8)
modelopt/torch/speculative/eagle/conversion.py (1)
35-39
: Avoid relying on private registry internals; register with a stable key and fail fast when unmapped.
- Using EagleDMRegistry._registry directly is a private detail and brittle.
- The key "base_model_class" is opaque; prefer a stable, informative key (e.g., fully-qualified class name).
- If no base class match is found, EagleDMRegistry.convert will raise a KeyError later with less context. Raise a clear error when the mapping cannot be inferred.
Apply this diff to make the registration more robust and explicit:
- if original_cls not in EagleDMRegistry:
-     for cls in EagleDMRegistry._registry:
-         if issubclass(original_cls, cls):
-             EagleDMRegistry.register({original_cls: "base_model_class"})(EagleDMRegistry[cls])
-             break
+ if original_cls not in EagleDMRegistry:
+     # Register a mapping for subclasses to the same dynamic class as their base.
+     found = False
+     for base_cls in list(EagleDMRegistry._registry.keys()):
+         if issubclass(original_cls, base_cls):
+             key = f"{original_cls.__module__}.{original_cls.__name__}"
+             EagleDMRegistry.register({original_cls: key})(EagleDMRegistry[base_cls])
+             found = True
+             break
+     if not found:
+         raise KeyError(
+             f"No Eagle dynamic mapping for {original_cls.__module__}.{original_cls.__name__}. "
+             f"Ensure a compatible base class is registered in EagleDMRegistry."
+         )
modelopt/torch/speculative/plugins/megatron_eagle.py (7)
748-751
: Disable sequence_parallel in offline mode — add a guard log.
For clarity during debugging, emit a one-time info when flipping sequence_parallel to False in offline runs so users aren’t surprised by the override.
Apply this diff:
  # sequence_parallel is not used in offline eagle
  if self.eagle_offline:
-     self.config.sequence_parallel = False
+     if self.config.sequence_parallel:
+         warnings.warn("EAGLE offline: forcing sequence_parallel = False")
+     self.config.sequence_parallel = False
768-768
: Update comment to reflect “offline” (detached class no longer exists).
Replace “detached eagle” with “offline eagle” to avoid confusion.
- # layer ids are not used in detached eagle, but we need to set this to have correct fc_input_size_multiplier
+ # Layer IDs are not used in offline eagle, but we set them to get the correct fc_input_size_multiplier
853-866
: Offline path: assert input shape assumptions and document expectations.
In offline mode, _get_eagle_input_hidden_states expects aux_hidden_states to already be concatenated ([s, b, k*h]). Add explicit asserts to catch silent shape mismatches.
  def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True):
      """When _aux_hidden_states is not empty, then this is EAGLE-3.
@@
-     if not self.eagle_offline:
+     if not self.eagle_offline:
          if len(self._aux_hidden_states) == 0:
              return hidden_states
@@
-     if apply_fc:
+     if apply_fc:
+         # In offline mode, hidden_states may already be [s, b, k*h]; validate k.
+         if self.eagle_offline:
+             h = self.config.hidden_size
+             assert hidden_states.shape[-1] % h == 0, (
+                 f"Expected aux_hidden_states hidden dim to be a multiple of {h}, "
+                 f"got {hidden_states.shape[-1]}"
+             )
          # [s / TP, b, 3h] -> [s / TP, b, h]
          return self.eagle_module.fc(hidden_states)[0]
1076-1086
: Make loss alignment explicit and robust across online/offline lengths.
Inferring offline via labels.shape[1] < eagle_logits.shape[0] is brittle. Align by checking the delta and fail fast on unexpected shapes.
  if self.eagle_self_logit_distillation:
      mapping = self.eagle_module.d2t if hasattr(self.eagle_module, "d2t") else None
      token_loss = self.kld(eagle_logits[:-1, :, :], logits[1:, :, :], mapping)
- elif labels.shape[1] < eagle_logits.shape[0]:
-     token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-2, :, :])
- else:
-     token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-1, :, :])
+ else:
+     target_len = labels.shape[1] - 1
+     pred_len = eagle_logits.shape[0]
+     if pred_len - 1 == target_len:
+         aligned = eagle_logits[:-1, :, :]
+     elif pred_len - 2 == target_len:
+         aligned = eagle_logits[:-2, :, :]
+     else:
+         raise ValueError(
+             f"Unexpected lengths for EAGLE loss: labels={labels.shape}, "
+             f"eagle_logits={eagle_logits.shape}"
+         )
+     token_loss = self.compute_language_model_loss(labels[:, 1:], aligned)
1298-1301
: Consolidate accuracy-slice logic to avoid off‑by‑one drift.
Four places differ only by the starting offset and offline end trim (-1 vs -2). Centralize this to a small helper to prevent future inconsistencies.
Add this helper (outside forward):
def _slice_for_acc(self, logits_1d: torch.Tensor, start: int) -> torch.Tensor:
    end = -2 if self.eagle_offline else -1
    return logits_1d[start:end, :, :]

Then replace:
- eagle_logits_0[:-2 or :-1] -> self._slice_for_acc(eagle_logits_0, 0)
- eagle_logits_1[1:-2 or 1:-1] -> self._slice_for_acc(eagle_logits_1, 1)
- eagle_logits_2[2:-2 or 2:-1] -> self._slice_for_acc(eagle_logits_2, 2)
- eagle_logits_3[3:-2 or 3:-1] -> self._slice_for_acc(eagle_logits_3, 3)
Also applies to: 1337-1343, 1379-1385, 1421-1426
1245-1249
: Avoid eager .cpu() copies on hot paths.
If return_eagle_inputs is used frequently, consider non_blocking transfers to CPU to minimize stalls.
- return {
-     "input_ids": input_ids.squeeze(0).cpu(),
-     "aux_hidden_states": eagle_module_input_hidden_states.squeeze(1).cpu(),
-     "hidden_states": hidden_states.squeeze(1).cpu(),
- }
+ return {
+     "input_ids": input_ids.squeeze(0).to("cpu", non_blocking=True),
+     "aux_hidden_states": eagle_module_input_hidden_states.squeeze(1).to(
+         "cpu", non_blocking=True
+     ),
+     "hidden_states": hidden_states.squeeze(1).to("cpu", non_blocking=True),
+ }
1461-1473
: Minor: clear hooks along early returns.
You clear self._aux_hidden_states on the early return with no children. If forward hooks were registered conditionally, also consider guarding to avoid accumulating state across calls.
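A defensive variant of that early return might look like this (a sketch; attribute names follow the surrounding review comments, not the exact source):

# Hypothetical guard: clear hook-accumulated aux states on every early return so a
# later forward call cannot pick up stale activations, however the hooks were registered.
if not self.post_process:
    self._aux_hidden_states.clear()
    return hidden_states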
🔇 Additional comments (5)
modelopt/torch/speculative/eagle/conversion.py (1)
41-41
: LGTM: unified conversion path through EagleDMRegistry.
Directly converting via EagleDMRegistry.convert(model) is consistent with the registry consolidation effort.
modelopt/torch/speculative/plugins/megatron_eagle.py (4)
55-55
: LGTM: remove OfflineEagleDMRegistry import.
Import surface now matches the consolidated registry usage.
818-820
: LGTM: propagate draft_vocab_size and has_lm_head into reused config.
Keeps EAGLE module output head consistent in the reuse path.
1323-1331
: LGTM: shape-based splitting of multi-step logits.
Indexing with logits_sbh.shape[0] makes the slicing resilient to seq-length variations.
Also applies to: 1364-1373, 1406-1415
1623-1653
: LGTM: offline/online flows preserved in pseudo_speculative_generate.
No issues spotted; consistent with the new offline gating and SP handling.
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
1196-1211
: Addressed prior review: offline forward now validates inputs and forbids return_eagle_inputs.
This resolves the earlier “NoneType later” hazard and ambiguous semantics in offline mode.
🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)
748-751
: Disable sequence_parallel in offline: good; add an explicit user-facing warning when flipping it.
Silent mutation of config can surprise callers. Emit a warning if it was True.
Apply:
  # sequence_parallel is not used in offline eagle
  if self.eagle_offline:
-     self.config.sequence_parallel = False
+     if self.config.sequence_parallel:
+         warnings.warn("EAGLE offline mode: forcibly disabling sequence_parallel.", stacklevel=1)
+     self.config.sequence_parallel = False
767-768
: Update terminology: “detached” → “offline”.
Comment still mentions “detached eagle”; recommend aligning to “offline” to avoid confusion.
- # layer ids are not used in detached eagle, but we need to set this to have correct fc_input_size_multiplier
+ # layer ids are not used in offline EAGLE, but we need to set this to have correct fc_input_size_multiplier
853-866
: Offline/path gate in _get_eagle_input_hidden_states: OK. Consider defensive clearing.
Looks correct. As a safety edge-case when toggling modes within a long-lived object, consider clearing _aux_hidden_states on offline entry to avoid accidental reuse.
  def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True):
@@
-     if not self.eagle_offline:
+     if not self.eagle_offline:
          if len(self._aux_hidden_states) == 0:
              return hidden_states
          # [s / TP, b, len(self._aux_hidden_states) * h]
          hidden_states = torch.cat(self._aux_hidden_states, dim=-1)
          self._aux_hidden_states.clear()
+     else:
+         # Ensure no stale states leak in if mode was toggled.
+         self._aux_hidden_states.clear()
1083-1086
: Tighten condition for offline label/logit alignment.
Use equality instead of less-than to avoid mis-slicing when labels are unexpectedly shorter.
- elif labels.shape[1] < eagle_logits.shape[0]:
+ elif labels.shape[1] + 1 == eagle_logits.shape[0]:
      token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-2, :, :])
🔇 Additional comments (5)
modelopt/torch/speculative/plugins/megatron_eagle.py (5)
55-55
: Registry consolidation LGTM.
Removing the OfflineEagleDMRegistry import and relying on EagleDMRegistry is consistent with the PR goal.
818-820
: Propagating draft_vocab_size and has_lm_head when reusing base decoder: LGTM.
Ensures output layer wiring matches reduced vocab configs.
1076-1077
: Docstring update acknowledging offline label shapes: LGTM.
1292-1294
: Shape-based slicing for multi-step chunks: LGTM.
Indexing using logits_sbh.shape[0] is more robust than relying on labels.
Also applies to: 1337-1338, 1379-1379, 1421-1421
1307-1308
: Offline Top‑1 slice offsets: guard for short seq_len.
':-2' (offline) yields zero-length spans for seq_len ≤ 2 (':-1' yields zero for seq_len ≤ 1). Clamp or condition the slice endpoints (e.g., use max(seq_len - 2, 0) or skip/adjust the metric computation) to avoid empty denominators.
Applies to modelopt/torch/speculative/plugins/megatron_eagle.py — lines 1307–1308, 1348–1349, 1390–1391, 1432–1433.
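A small helper along those lines (a sketch under the assumption that the metric can simply be skipped when the trimmed span is empty; the target alignment is inferred from the first accuracy block and should be checked against the actual code):

import torch

def top1_accuracy_or_none(step_logits: torch.Tensor, labels: torch.Tensor,
                          start: int, offline: bool):
    """Return top-1 accuracy for one draft step, or None if the slice would be empty."""
    end = step_logits.shape[0] - (2 if offline else 1)
    if end <= start:  # nothing left to score for very short sequences
        return None
    top1 = step_logits[start:end].transpose(0, 1).argmax(dim=-1)  # [b, n]
    target = labels[:, start + 1 : start + 1 + top1.shape[1]]
    return torch.eq(target, top1).sum() / top1.numel()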
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #321      +/-   ##
==========================================
- Coverage   73.82%   73.81%   -0.01%
==========================================
  Files         172      172
  Lines       17438    17436       -2
==========================================
- Hits        12874    12871       -3
- Misses       4564     4565       +1

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
modelopt/torch/speculative/plugins/transformers.py (4)
561-576
: Bug: dtype mismatch (bool vs float) when concatenating attention masks.
torch.cat over attention_mask_0 (float) and zero_mask/mask_2_2 (bool) will error. Keep masks in the same floating dtype and use 0/1 sentinels before masked_fill.
Apply:
- zero_mask = torch.ones_like(attention_mask_0).bool()
+ zero_mask = torch.ones_like(attention_mask_0)
  mask_2_1 = attention_mask_0.clone().detach()
  mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
- mask_2_2 = torch.ones_like(attention_mask_0).bool()
+ mask_2_2 = torch.ones_like(attention_mask_0)
  for i in range(1, seq_length - 1):
-     mask_2_2[:, :, i, i] = False
+     mask_2_2[:, :, i, i] = 0
  cat_attention_mask = torch.cat(
      (
          torch.cat((attention_mask_0, zero_mask), dim=-1),
          torch.cat((mask_2_1, mask_2_2), dim=-1),
      ),
      dim=-2,
  )
  cat_attention_mask = cat_attention_mask.masked_fill(cat_attention_mask == 1, dtypemin)
593-617
: Same dtype bug in the 3-block concat path (second step).
Ensure all masks are float tensors; use 0/1 sentinels consistently.
- zero_mask = torch.ones_like(attention_mask_0).bool()
+ zero_mask = torch.ones_like(attention_mask_0)
  mask_2_1 = attention_mask_0.clone().detach()
  mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
- mask_2_2 = torch.ones_like(attention_mask_0).bool()
+ mask_2_2 = torch.ones_like(attention_mask_0)
  for i in range(1, seq_length - 1):
-     mask_2_2[:, :, i, i] = False
+     mask_2_2[:, :, i, i] = 0
  mask_3_1 = mask_2_1.clone().detach()
  mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:]
  mask_3_2 = mask_2_2.clone().detach()
  mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:]
- mask_3_2[:, :, 1, 0] = True
+ mask_3_2[:, :, 1, 0] = 1
  mask_3_3 = mask_2_2.clone().detach()
- mask_3_3[:, :, 1, 1] = True
+ mask_3_3[:, :, 1, 1] = 1
633-671
: Same dtype bug in the 4-block concat path (third step).
Fix bool/float mixing; keep sentinels numeric.
- zero_mask = torch.ones_like(attention_mask_0).bool()
+ zero_mask = torch.ones_like(attention_mask_0)
  mask_2_1 = attention_mask_0.clone().detach()
  mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
- mask_2_2 = torch.ones_like(attention_mask_0).bool()
+ mask_2_2 = torch.ones_like(attention_mask_0)
  for i in range(1, seq_length - 1):
-     mask_2_2[:, :, i, i] = False
+     mask_2_2[:, :, i, i] = 0
  mask_3_1 = mask_2_1.clone().detach()
  mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:]
  mask_3_2 = mask_2_2.clone().detach()
  mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:]
- mask_3_2[:, :, 1, 0] = True
+ mask_3_2[:, :, 1, 0] = 1
  mask_3_3 = mask_2_2.clone().detach()
- mask_3_3[:, :, 1, 1] = True
+ mask_3_3[:, :, 1, 1] = 1
  mask_4_1 = mask_3_1.clone().detach()
  mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:]
  mask_4_2 = mask_3_2.clone().detach()
  mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:]
- mask_4_2[:, :, 2, 0] = True
+ mask_4_2[:, :, 2, 0] = 1
  mask_4_3 = mask_3_3.clone().detach()
  mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:]
- mask_4_3[:, :, 2, 1] = True
+ mask_4_3[:, :, 2, 1] = 1
  mask_4_4 = mask_3_3.clone().detach()
- mask_4_4[:, :, 2, 2] = True
+ mask_4_4[:, :, 2, 2] = 1
120-132
: Parameter name mismatch — use cache_position, not rcache_position.
transformers.py: forward(...) defines cache_position (lines 107 and 763), but the call at line 130 passes rcache_position=cache_position; rcache_position appears nowhere else — replace with cache_position=cache_position or verify the wrapped model actually expects rcache_position.
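The fix itself would be a one-keyword change at the call site, roughly as below (a sketch; the surrounding arguments are placeholders, not the file's exact code):

# Hypothetical call site inside forward(); only the keyword name matters here.
outputs = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    cache_position=cache_position,  # was: rcache_position=cache_position
)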
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)
692-701
: Use inference_mode for read-only blocks.
Replace torch.no_grad() with torch.inference_mode() where no tensors require grad to reduce dispatcher overhead and enable viewless inference.
Example:
- with torch.no_grad() if freeze_base_model else contextlib.nullcontext():
+ with torch.inference_mode() if freeze_base_model else contextlib.nullcontext():

And similarly for the later embed-only sections.
Also applies to: 842-853, 874-888, 922-941, 965-984
🔇 Additional comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)
49-49
: Registry consolidation LGTM — verify no leftover Offline registry usages.
File: modelopt/torch/speculative/plugins/transformers.py — import now:
from ..eagle.conversion import EagleDMRegistry
The rg search you ran returned no output; absence of matches is inconclusive — re-run locally/CI:
rg -nP -C2 '\b(OfflineEagleDMRegistry|DetachedEagleGPTModel|DetachedHFEagleModel)\b' || true
rg -n -S -C2 'OfflineEagle|DetachedEagle|eagle.conversion|EagleDMRegistry' || true
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
1199-1211
: Fix offline inputs: allow EAGLE‑1, validate args, and centralize PP early‑return.
Offline forward still requires aux_hidden_states and places the PP early‑return only in the online branch. This blocks EAGLE‑1 offline and can NPE when aux is absent. Move the PP early‑return out of the if/else and only require aux when use_aux_hidden_state=True. Also select src_states accordingly.
Apply this diff:
@@
- if self.eagle_offline:
-     # aux_hidden_states and hidden_states are provided for offline eagle
-     # _base_model_forward is skipped
-     if return_eagle_inputs:
-         raise ValueError("return_eagle_inputs is unsupported in EAGLE offline mode.")
-     aux_hidden_states = kwargs.get("aux_hidden_states")
-     hidden_states = kwargs.get("hidden_states")
-     if aux_hidden_states is None or hidden_states is None:
-         raise ValueError(
-             "EAGLE offline mode requires kwargs: aux_hidden_states=[s,b,k*h], "
-             "hidden_states=[s,b,h]."
-         )
+ if self.eagle_offline:
+     # hidden_states (required) and aux_hidden_states (optional for EAGLE‑1) are provided for offline EAGLE.
+     # _base_model_forward is skipped
+     if return_eagle_inputs:
+         raise ValueError("return_eagle_inputs is unsupported in EAGLE offline mode.")
+     hidden_states = kwargs.get("hidden_states", None)
+     if hidden_states is None:
+         raise ValueError("EAGLE offline mode requires kwargs: hidden_states=[s,b,h].")
+     aux_hidden_states = kwargs.get("aux_hidden_states", None)
+     if self.eagle_config.use_aux_hidden_state and aux_hidden_states is None:
+         raise ValueError(
+             "EAGLE‑3 offline additionally requires aux_hidden_states=[s,b,k*h] when "
+             "use_aux_hidden_state=True."
+         )
  else:
@@
-     # Typically, this is only the case when PP > 1.
-     if not self.post_process:
-         return hidden_states
+ # Typically, this is only the case when PP > 1.
+ if not self.post_process:
+     return hidden_states
@@
- if self.eagle_offline:
-     eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(
-         aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state
-     )
+ if self.eagle_offline:
+     src_states = (
+         aux_hidden_states if self.eagle_config.use_aux_hidden_state else hidden_states
+     )
+     eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(
+         src_states, apply_fc=self.eagle_config.use_aux_hidden_state
+     )

Also applies to: 1226-1229, 1235-1238
🧹 Nitpick comments (6)
modelopt/torch/speculative/plugins/megatron_eagle.py (6)
1241-1256
: Preserve batch dims for return_eagle_inputs (avoid squeeze assumptions).
Squeezing dim 1 breaks when batch_size > 1. Return tensors with their batch dims intact.
- return {
-     "input_ids": input_ids.squeeze(0).cpu(),
-     "aux_hidden_states": eagle_module_input_hidden_states.squeeze(1).cpu(),
-     "hidden_states": hidden_states.squeeze(1).cpu(),
- }
+ return {
+     "input_ids": input_ids.cpu(),
+     "aux_hidden_states": eagle_module_input_hidden_states.cpu(),
+     "hidden_states": hidden_states.cpu(),
+ }
767-773
: Update comment: “detached eagle” wording is outdated.
This PR removes DetachedEagle; update the comment to reflect “offline EAGLE” semantics.
- # layer ids are not used in detached eagle, but we need to set this to have correct fc_input_size_multiplier
+ # Layer IDs are not used in offline EAGLE‑1, but we set them for correct fc_input_size_multiplier in EAGLE‑3.
1083-1085
: Make offline loss branch explicit (avoid shape‑heuristic).
Using labels.shape[1] < eagle_logits.shape[0] is brittle. Gate by eagle_offline and optionally assert expected alignment.
- elif labels.shape[1] < eagle_logits.shape[0]:
+ elif self.eagle_offline:
+     assert eagle_logits.shape[0] >= labels.shape[1] + 1, \
+         "Offline EAGLE expects logits length >= labels length + 1."
      token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-2, :, :])
1309-1311
: DRY the offline/online top‑1 slicing logic.
The repeated ternaries ([:-2] vs. [:-1]) are easy to drift. Factor a small helper to slice logits consistently.
+ def _slice_for_top1(self, logits):
+     return logits[:-2, :, :] if self.eagle_offline else logits[:-1, :, :]
@@
- gathered_logits = gather_from_tensor_model_parallel_region(
-     eagle_logits_0[:-2, :, :] if self.eagle_offline else eagle_logits_0[:-1, :, :]
- )
+ gathered_logits = gather_from_tensor_model_parallel_region(
+     self._slice_for_top1(eagle_logits_0)
+ )
@@
- gathered_logits = gather_from_tensor_model_parallel_region(
-     eagle_logits_1[1:-2, :, :] if self.eagle_offline else eagle_logits_1[1:-1, :, :]
- )
+ gathered_logits = gather_from_tensor_model_parallel_region(
+     self._slice_for_top1(eagle_logits_1[1:, :, :])
+ )
@@
- gathered_logits = gather_from_tensor_model_parallel_region(
-     eagle_logits_2[2:-2, :, :] if self.eagle_offline else eagle_logits_2[2:-1, :, :]
- )
+ gathered_logits = gather_from_tensor_model_parallel_region(
+     self._slice_for_top1(eagle_logits_2[2:, :, :])
+ )
@@
- gathered_logits = gather_from_tensor_model_parallel_region(
-     eagle_logits_3[3:-2, :, :] if self.eagle_offline else eagle_logits_3[3:-1, :, :]
- )
+ gathered_logits = gather_from_tensor_model_parallel_region(
+     self._slice_for_top1(eagle_logits_3[3:, :, :])
+ )

Also applies to: 1350-1351, 1392-1393, 1434-1435
846-866
: Online/Offline hidden state selection looks correct; small doc tweak.
Logic to bypass aux accumulation in offline is right. Consider noting in the docstring that offline passes either aux (EAGLE‑3) or base hidden_states (EAGLE‑1).
1251-1251
: Optional: avoid unnecessary full TP gather of base logits on return_eagle_inputs.
If the only consumer is offline precompute, gathering logits_sbh here is unnecessary. Returning CPU’d inputs without TP gather can reduce overhead.
If you keep it, please confirm a downstream consumer requires full‑vocab logits. Otherwise, drop this gather.
🔇 Additional comments (3)
modelopt/torch/speculative/plugins/megatron_eagle.py (3)
141-144
: Doc vs. implementation mismatch: “at least” should be >= not >.
- """Check if megatron-core is least this version.""" - return Version(megatron.core.__version__) > Version(target_version) + """Check if megatron-core is at least this version.""" + return Version(megatron.core.__version__) >= Version(target_version)
818-820
: Good: propagate draft_vocab_size/has_lm_head when reusing base decoder.
This keeps EagleModule consistent with the reduced draft vocab. 👍
748-751
: Disabling sequence_parallel in offline: confirm interaction with checkpoints and SP‑trained bases.
ripgrep shows gather_from_sequence_parallel_region / scatter_to_sequence_parallel_region are used throughout megatron_eagle.py but guarded by if self.config.sequence_parallel (examples: ~lines 901–904, 941–943, 961–964, 1245–1250, 1510–1511, 1547–1549, 1670–1671) — so forcing sequence_parallel=False will skip those calls, but this alone doesn't guarantee safety.
- Verify offline runs do not load SP‑sharded checkpoints or per‑rank weight shards (search for load_state_dict / load_checkpoint / state_dict / tensor_model_parallel loader code).
- Ensure no tensor‑sharded buffers are referenced unguarded during offline init; add an explicit assertion or a conversion/loader branch when eagle_offline=True to fail fast if an SP checkpoint is attempted to be loaded.
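A fail-fast guard in that spirit could look roughly like this (hypothetical; the checkpoint metadata field name is an assumption, not an existing API):

# Hypothetical guard at offline-EAGLE load time: refuse sequence-parallel checkpoints,
# since sequence_parallel is forcibly disabled in offline mode.
if self.eagle_offline and checkpoint_args.get("sequence_parallel", False):
    raise ValueError(
        "EAGLE offline mode disables sequence_parallel; "
        "load a non-SP checkpoint or re-shard it before offline training."
    )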
Actionable comments posted: 0
♻️ Duplicate comments (3)
modelopt/torch/speculative/plugins/megatron_eagle.py (3)
1199-1211
: Offline EAGLE-1 is blocked; make aux_hidden_states optional and validate by config.
Requiring aux_hidden_states unconditionally prevents offline EAGLE‑1. Gate the requirement on use_aux_hidden_state and always require hidden_states. Also reject return_eagle_inputs early (kept).
Apply:
  if self.eagle_offline:
      # aux_hidden_states and hidden_states are provided for offline eagle
      # _base_model_forward is skipped
      if return_eagle_inputs:
          raise ValueError("return_eagle_inputs is unsupported in EAGLE offline mode.")
-     aux_hidden_states = kwargs.get("aux_hidden_states")
-     hidden_states = kwargs.get("hidden_states")
-     if aux_hidden_states is None or hidden_states is None:
-         raise ValueError(
-             "EAGLE offline mode requires kwargs: aux_hidden_states=[s,b,k*h], "
-             "hidden_states=[s,b,h]."
-         )
+     hidden_states = kwargs.get("hidden_states", None)
+     if hidden_states is None:
+         raise ValueError("EAGLE offline mode requires kwargs: hidden_states=[s,b,h].")
+     aux_hidden_states = kwargs.get("aux_hidden_states", None)
+     if self.eagle_config.use_aux_hidden_state and aux_hidden_states is None:
+         raise ValueError(
+             "EAGLE‑3 offline additionally requires aux_hidden_states=[s,b,k*h] when "
+             "use_aux_hidden_state=True."
+         )
1226-1228
: Centralize PP early‑return so it runs for both online and offline paths.
Right now the early return is only reachable in the online branch. Move it out so offline PP does not proceed to output/loss.
- # Typically, this is only the case when PP > 1.
- if not self.post_process:
-     return hidden_states
+ # Typically, this is only the case when PP > 1.
+ if not self.post_process:
+     return hidden_states
1235-1238
: Offline EAGLE input should come from aux or base hidden by config.
Use aux when use_aux_hidden_state=True (EAGLE‑3), otherwise pass hidden_states (EAGLE‑1). Also set apply_fc accordingly.
- if self.eagle_offline:
-     eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(
-         aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state
-     )
+ if self.eagle_offline:
+     src_states = (
+         aux_hidden_states
+         if self.eagle_config.use_aux_hidden_state
+         else hidden_states
+     )
+     eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(
+         src_states, apply_fc=self.eagle_config.use_aux_hidden_state
+     )
🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (4)
1306-1321
: Guard accuracy reporting against empty slices (short sequences).
When seq_len is too short, slices like [:-2] can be empty → argmax and division by zero crash. Skip metrics if there is no token to score.
Apply to each block (1st→4th) similarly; example shown for the 1st block:
- with torch.no_grad():
-     gathered_logits = gather_from_tensor_model_parallel_region(
-         eagle_logits_0[:-2, :, :] if self.eagle_offline else eagle_logits_0[:-1, :, :]
-     )
-     eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-     if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-         eagle_top1 += self.eagle_module.d2t[eagle_top1]
-     top1_p = torch.eq(labels[:, 1:], eagle_top1).sum() / eagle_top1.numel()
-     acc.append(top1_p)
+ with torch.no_grad():
+     sl = eagle_logits_0.shape[0] - (2 if self.eagle_offline else 1)
+     if sl > 0:
+         logits_slice = (
+             eagle_logits_0[:sl, :, :]
+         )
+         gathered_logits = gather_from_tensor_model_parallel_region(logits_slice)
+         eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
+         if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+             eagle_top1 += self.eagle_module.d2t[eagle_top1]
+         denom = max(1, eagle_top1.numel())
+         top1_p = torch.eq(labels[:, 1:], eagle_top1).sum() / denom
+         acc.append(top1_p)

Repeat the same shape guard for the 2nd, 3rd, and 4th accuracy blocks.
Also applies to: 1348-1363, 1390-1405, 1432-1447
1076-1086
: Loss shape handling: prefer explicit offline flag over shape heuristics.
Using labels.shape[1] < eagle_logits.shape[0] to infer offline mode is brittle. Gate the [:-2] slice with self.eagle_offline for clarity and safety.
- elif labels.shape[1] < eagle_logits.shape[0]:
-     token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-2, :, :])
- else:
-     token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-1, :, :])
+ else:
+     slice_ = eagle_logits[:-2, :, :] if self.eagle_offline else eagle_logits[:-1, :, :]
+     token_loss = self.compute_language_model_loss(labels[:, 1:], slice_)
768-771
: Stale terminology: “detached eagle” in comment.
Detached wrapper was removed; update the comment to “offline EAGLE”.
- # layer ids are not used in detached eagle, but we need to set this to have correct fc_input_size_multiplier
+ # Layer ids are not used in offline EAGLE, but we set them to derive the correct fc_input_size_multiplier.
640-647
: Comment contradicts code: clarify EAGLE‑1/‑3 inputs.
Code concatenates [embeddings, hidden_states] → 2h then fc for EAGLE‑1; EAGLE‑3 uses aux + specialized attention to reduce 2h→h. Update the comment.
- # EAGLE-1 uses [s, b, h] input but EAGLE-3 uses [s, b, 2h] input
+ # EAGLE‑1: concat [embeddings, hidden_states] → [s,b,2h], then fc → [s,b,h].
+ # EAGLE‑3: consumes aux/hidden; first attention reduces 2h → h (specialized qkv).
🔇 Additional comments (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
748-751
: Disabling sequence_parallel in offline mode looks correct.
This avoids ColumnParallelLinear sequence gathers that don't apply to precomputed inputs.
Confirm there’s no other SP‑dependent op executed in offline path (e.g., gathering/scattering around Eagle inputs).
818-820
: Draft‑vocab and has_lm_head propagation: LGTM.
Keeps EAGLE module output layer consistent with reduced vocab scenarios.
Need to run the regression test to see if there is any impact on the Megatron online training.
Actionable comments posted: 0
🧹 Nitpick comments (1)
modelopt/torch/export/plugins/megatron_importer.py (1)
515-521
: Harden softmax_offset gating to avoid AttributeError/KeyError
Good tightening of the None check. However, hasattr(attention.core_attention, ...) still dereferences attention.core_attention unconditionally and will raise if attention lacks that attribute. Also, indexing self.rules["softmax_offset"] can KeyError when the rule isn’t present for some arch. Guard both.
Apply this diff:
- if (
-     hasattr(attention.core_attention, "softmax_offset")
-     and attention.core_attention.softmax_offset is not None
- ):
-     self.rules["softmax_offset"](
-         attention.core_attention.softmax_offset, layer_id
-     )
+ core_attn = getattr(attention, "core_attention", None)
+ softmax_offset = getattr(core_attn, "softmax_offset", None)
+ softmax_rule = self.rules.get("softmax_offset")
+ if softmax_offset is not None and callable(softmax_rule):
+     softmax_rule(softmax_offset, layer_id)

Please confirm all non-MLA attention variants hit here always have core_attention (or this guard will be needed). Also confirm the rule exists for arches where softmax_offset is expected (no unexpected skips).
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
1194-1208
: Validate offline mode inputs and forbid return_eagle_inputs.
The offline forward path correctly validates required inputs, but there is a potential issue with the validation logic.
This matches the existing comment about validating required inputs and forbidding return_eagle_inputs in offline mode.
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
767-767
: Clarify the comment about offline mode.
The comment mentions that "layer ids are not used in offline eagle" but this could be more precise. EAGLE-3 still uses layer IDs internally for determining the FC input size multiplier, even in offline mode.
Apply this diff to improve the comment clarity:
- # layer ids are not used in offline eagle, but we need to set this to have correct fc_input_size_multiplier
+ # layer ids are not actively extracted in offline eagle, but we still need to set this to have correct fc_input_size_multiplier
🔇 Additional comments (8)
modelopt/torch/speculative/plugins/megatron_eagle.py (8)
818-820
: Configuration parameters propagated correctly.
The addition of draft_vocab_size and has_lm_head parameters to the eagle_config ensures proper configuration of the EagleModule for both online and offline modes.
847-866
: Offline/online mode handling in input hidden states.
The logic correctly handles both offline and online modes. For offline mode, it uses the provided aux_hidden_states directly; for online mode, it concatenates the aux_hidden_states collected from forward hooks.
1209-1222
: Online forward path preserved correctly.
The base model forward path is correctly preserved for online mode, with proper parameter passing to _base_model_forward.
1224-1226
: Pipeline parallel early return centralized.
The early return for non-post-process pipeline stages is now properly centralized and applies to both online and offline modes, which is correct.
1233-1236
: Offline eagle input processing.
The offline mode correctly processes aux_hidden_states using _get_eagle_input_hidden_states with the appropriate apply_fc parameter based on self.eagle_config.use_aux_hidden_state.
1238-1254
: Return eagle inputs logic preserved.
The logic for returning eagle inputs in online mode is preserved and correctly returns the expected dictionary format with input_ids, aux_hidden_states, and hidden_states.
1282-1293
: Labels padding for offline mode.
The logic handles the case where labels may be one token shorter than input_ids in offline mode by padding with zeros. This is a reasonable approach for offline training scenarios.
1344-1344
: No change required — labels-based slicing is correct.
The code pads offline labels when labels.shape[1] == input_ids.shape[1] - 1 (megatron_eagle.py ~1282–1286) and uses labels.shape[1] consistently for eagle_logits slicing across drafts (lines ~1301, 1344, 1386, 1428), so the dynamic slicing handles both padded and unpadded cases.
Actionable comments posted: 2
♻️ Duplicate comments (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
1197-1208
: Offline forward path: validate required inputs and forbid return_eagle_inputs.
- If eagle_offline=True, aux_hidden_states and hidden_states must be provided; currently missing keys lead to NoneType errors later.
- return_eagle_inputs should be rejected in offline mode to avoid ambiguous semantics.
1224-1226
: PP early return should be outside the online/offline branching.
The pipeline parallel early return (if not self.post_process) is only in the online path. This prevents offline mode from working correctly with pipeline parallelism. Move the PP check outside of the conditional to handle both paths:
      return_eagle_inputs=return_eagle_inputs,
  )
- # Typically, this is only the case when PP > 1.
- if not self.post_process:
-     return hidden_states
+ # Typically, this is only the case when PP > 1.
+ if not self.post_process:
+     return hidden_states
🧹 Nitpick comments (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
1344-1344
: Inconsistent indexing for eagle_logits extraction in offline mode.
The code uses logits_sbh.shape[0] for eagle_logits slicing, which differs between online and offline modes. Using negative indexing (-labels.shape[1]:) would be more consistent and less error-prone. Apply consistent negative indexing for all eagle_logits extractions:
- eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :]
+ eagle_logits_1 = eagle_logits_2x[-labels.shape[1]:, :, :]
- eagle_logits_2 = eagle_logits_3x[-labels.shape[1] :, :, :]
+ eagle_logits_2 = eagle_logits_3x[-labels.shape[1]:, :, :]
- eagle_logits_3 = eagle_logits_4x[-labels.shape[1] :, :, :]
+ eagle_logits_3 = eagle_logits_4x[-labels.shape[1]:, :, :]
1282-1292
: Replace zero-padding with an ignore-index (or warn) when labels are 1 token short.
labels are padded with 0 when labels.shape[1] == input_ids.shape[1] - 1 (modelopt/torch/speculative/plugins/megatron_eagle.py:1282-1292), which will make the last token contribute to the loss if logit distillation is disabled. eagle_self_logit_distillation exists in the codebase (passed through config and asserted in the plugin), but the padding branch still runs in some cases — avoid training corruption by masking instead of zero-padding.
- Recommended fix (preferred): pad with the loss ignore index (e.g. -100) or the project’s IGNORE_TOKEN_ID:
right_token_pad = torch.full((labels.shape[0], 1), -100, dtype=labels.dtype, device=labels.device)
labels = torch.cat((labels, right_token_pad), dim=-1)

- Alternate (minimum): log a warning when padding happens and logit distillation is disabled:

if not self.eagle_self_logit_distillation:
    import warnings; warnings.warn("Label padding in offline mode may introduce training errors when logit distillation is disabled.")

- Location: modelopt/torch/speculative/plugins/megatron_eagle.py lines ~1282-1292 (padding) and ~736-761 (eagle_self_logit_distillation use/assert).
🔇 Additional comments (3)
modelopt/torch/speculative/eagle/conversion.py (2)
34-39
: LGTM! Dynamic registry registration handles offline mode elegantly.
The subclass detection and registration logic correctly enables dynamic modules that aren't explicitly registered to work with the Eagle registry, which is essential for handling offline mode models.
41-51
: All Eagle config fields are handled — no change required.
Base modify (modelopt/torch/speculative/eagle/eagle_model.py:29–51) assigns the passed params; plugins/transformers.py (371–418) and plugins/megatron_eagle.py (712–806, and usages throughout) consume and validate them (e.g., eagle_architecture_config → PretrainedConfig/from_dict and parallel/aux checks; eagle_freeze_base_model freezes params; eagle_loss_decay_factor used in loss scaling; eagle_self_logit_distillation validated against draft_vocab_size).
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
846-865
: LGTM! Clear separation of online/offline logic in _get_eagle_input_hidden_states.
The method correctly handles both online and offline modes, with appropriate FC application based on the eagle_offline flag and aux_hidden_states availability.
eagle_config.use_mtp_layernorm = self.eagle_config.use_mtp_layernorm
eagle_config.draft_vocab_size = self.eagle_config.draft_vocab_size
eagle_config.has_lm_head = self.eagle_config.has_lm_head
🧩 Analysis chain
Ensure draft_vocab_size and has_lm_head are properly initialized.
The code passes draft_vocab_size and has_lm_head to the eagle_config, but these fields need to exist in the eagle_architecture_config.
🏁 Script executed:
#!/bin/bash
# Verify that draft_vocab_size and has_lm_head are present in eagle_architecture_config
rg -n "eagle_architecture_config.*draft_vocab_size|eagle_architecture_config.*has_lm_head" --type py
# Check for default values or initialization of these fields
ast-grep --pattern 'eagle_architecture_config.get("draft_vocab_size"$$$)'
ast-grep --pattern 'eagle_architecture_config.get("has_lm_head"$$$)'
Length of output: 329
Initialize draft_vocab_size and has_lm_head in eagle_architecture_config
modelopt/torch/speculative/plugins/megatron_eagle.py:817–819 assigns eagle_config.draft_vocab_size and eagle_config.has_lm_head, but those keys only appear in tests (tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:81,231). Add these fields (with sensible defaults) to the eagle_architecture_config definition or validate/guard before assigning to avoid runtime errors.
🤖 Prompt for AI Agents
In modelopt/torch/speculative/plugins/megatron_eagle.py around lines 817-819,
the code unconditionally assigns eagle_config.draft_vocab_size and
eagle_config.has_lm_head but those fields are not defined on
eagle_architecture_config and only referenced in tests; either add
draft_vocab_size (default 0 or None) and has_lm_head (default False) to the
eagle_architecture_config dataclass/definition where the config is declared, or
wrap these assignments with guards like "if hasattr(self.eagle_config,
'draft_vocab_size')/hasattr(self.eagle_config, 'has_lm_head')" before assigning
to avoid AttributeError at runtime—choose adding fields if they are intrinsic to
the model, otherwise guard the assignments and document the defaults.
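Following the guard option above, the assignments could be wrapped roughly like this (a sketch; the getattr defaults mirror the suggestion and should be revisited if the fields are intrinsic to the model):

# Hypothetical guarded assignment so missing config fields don't raise at runtime.
eagle_config.draft_vocab_size = getattr(self.eagle_config, "draft_vocab_size", None)
eagle_config.has_lm_head = getattr(self.eagle_config, "has_lm_head", False)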
Left a question.
LGTM;
What does this PR do?
Type of change: refactor
Overview:
Following the HF implementation, remove the dedicated Detached Eagle class and handle the offline case in the same class as the online mode.
Usage
# Add a code snippet demonstrating how to use this
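A minimal sketch of what usage might look like after this change (the function and config-key names follow the conversion flow summarized in the sequence diagram above; treat them as assumptions, not the exact public API):

from modelopt.torch.speculative.eagle.conversion import convert_to_eagle_model

# Hypothetical config: keys mirror the options passed through modify() in the
# conversion flow; eagle_offline selects the precomputed-hidden-states path.
eagle_config = {
    "eagle_offline": True,
    "eagle_self_logit_distillation": True,
    "eagle_freeze_base_model": True,
    "eagle_architecture_config": {},  # draft-module architecture settings (elided)
}

model = convert_to_eagle_model(model, eagle_config)

# Offline training then feeds precomputed states instead of running the base model:
# model(input_ids, labels=labels, hidden_states=hs, aux_hidden_states=aux_hs)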
Testing
Tested with both online and offline Qwen3 30B for EAGLE-3 training.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor